#include "pa.cuh"
                        
// Signature of the generated paged-attention launcher. It is kept in a macro
// so that the extern "C" declaration and the definition that follow are
// spelled identically. The function name is substituted when this template is
// rendered.
//
// Most buffers are passed as untyped pointers; the definition casts them:
//   out_ptr, tmp_out_ptr_           -> output element type
//   exp_sums_ptr_, max_logits_ptr_  -> float
//   query_ptr                       -> query element type
//   key_cache_ptr, value_cache_ptr  -> KV-cache element type
//   stream                          -> hipStream_t
// The remaining arguments are forwarded to the kernels unchanged, except that
// max_context_len, num_seqs, num_kv_heads and num_heads are only used on the
// host to size the launch grids.
// NOTE(review): buffer shapes and the memory space of each pointer are not
// visible here; presumably all of them are device pointers - confirm against
// the caller and pa.cuh.
// NOTE: keep comments out of the macro body; a "//" comment in front of a
// line continuation would swallow the following line.
#define FUNCTION_DEFINE                                 \
    void {{func_name}}(void* out_ptr,                   \
                      void* exp_sums_ptr_,              \
                      void* max_logits_ptr_,            \
                      void* tmp_out_ptr_,               \
                      void* query_ptr,                  \
                      void* key_cache_ptr,              \
                      void* value_cache_ptr,            \
                      const float scale,                \
                      int* block_tables_ptr,            \
                      int* context_lens_ptr,            \
                      const int max_context_len,        \
                      const int num_seqs,               \
                      const int num_kv_heads,           \
                      const int num_heads,              \
                      const int max_num_blocks_per_seq, \
                      const int q_stride,               \
                      const int kv_block_stride,        \
                      const int kv_head_stride,         \
                      const float* alibi_slopes_ptr,    \
                      const float* q_scale_ptr,         \
                      const float* k_scale_ptr,         \
                      const float* v_scale_ptr,         \
                      const float* fp8_out_scale_ptr,   \
                      void* stream)

// Declare the launcher with C linkage so that its symbol name is not
// C++-mangled. NOTE(review): presumably this lets the compiled module be
// resolved by plain function name at load time - confirm against the loader.
extern "C" {
FUNCTION_DEFINE;
}
                        
// Host-side launcher for one rendered paged-attention configuration.
//
// Two kernels are enqueued back to back on the caller-supplied stream:
//   1. paged_attention_ll4mi_QKV_mfma16_kernel
//        grid  = (num_seqs, max_num_partitions, num_kv_heads)
//        block = NTHR threads
//      It receives the query, the KV caches, the block tables / context
//      lengths and the exp_sums / max_logits / tmp_out scratch buffers.
//   2. paged_attention_ll4mi_reduce_kernel
//        grid  = (num_heads, num_seqs, <mtp template value>)
//        block = head_size threads
//      It receives the same scratch buffers plus the final output pointer and
//      fp8_out_scale_ptr.
// NOTE(review): presumably kernel 1 writes per-partition partial results into
// the scratch buffers and kernel 2 combines them into out_ptr - confirm
// against pa.cuh.
//
// max_num_partitions is ceil(max_context_len / PARTITION_SIZE).
//
// The function returns without launching anything when there is no work
// (non-positive num_seqs, num_heads, num_kv_heads or max_context_len).
// No synchronization is performed and launch errors are not queried here.
FUNCTION_DEFINE
{
    constexpr int PARTITION_SIZE = 256;
    constexpr int NTHR           = 256;
    constexpr int head_size      = {{head_size}};

    // Empty problem: nothing to compute. Bailing out here also avoids
    // building a grid with a zero (or wrapped-around negative) dimension,
    // which is not a valid launch configuration, and avoids a modulo by zero
    // in the divisibility check below.
    if(num_seqs <= 0 || num_heads <= 0 || num_kv_heads <= 0 || max_context_len <= 0)
    {
        return;
    }

    // Query heads must split evenly across the KV heads.
    assert(num_heads % num_kv_heads == 0);

    const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);

    // Typed views of the untyped buffers.
    float* exp_sums_ptr          = reinterpret_cast<float*>(exp_sums_ptr_);
    float* max_logits_ptr        = reinterpret_cast<float*>(max_logits_ptr_);
    {{out_dtype}}* tmp_out_ptr   = reinterpret_cast<{{out_dtype}}*>(tmp_out_ptr_);
    {{out_dtype}}* final_out_ptr = reinterpret_cast<{{out_dtype}}*>(out_ptr);
    {{dtype}}* q_ptr             = reinterpret_cast<{{dtype}}*>(query_ptr);
    {{kv_dtype}}* k_cache_ptr    = reinterpret_cast<{{kv_dtype}}*>(key_cache_ptr);
    {{kv_dtype}}* v_cache_ptr    = reinterpret_cast<{{kv_dtype}}*>(value_cache_ptr);

    // Both kernels go to the same stream, so the second one starts only
    // after the first one has finished.
    hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);

    // ---- Kernel 1: attention over the partitions of each sequence ----
    const dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
    const dim3 block(NTHR);

    paged_attention_ll4mi_QKV_mfma16_kernel<{{dtype}},
                                            {{out_dtype}},
                                            {{kv_dtype}},
                                            {% if fp8_kv_dtype == 'auto' %}
                                            vllm::Fp8KVCacheDataType::kAuto,
                                            {% else %}
                                            vllm::Fp8KVCacheDataType::kFp8E4M3,
                                            {% endif %}
                                            {{block_size}},
                                            head_size,
                                            NTHR,
                                            {{alibi_enabled}},
                                            {{gqa_ratio}},
                                            {{mtp}},
                                            {{quant_method}},
                                            {{"true" if v_shuffle else "false"}}>
        <<<grid, block, 0, hip_stream>>>(q_ptr,
                                         k_cache_ptr,
                                         v_cache_ptr,
                                         scale,
                                         block_tables_ptr,
                                         context_lens_ptr,
                                         nullptr, // optional kernel argument left unset; see pa.cuh
                                         max_num_blocks_per_seq,
                                         alibi_slopes_ptr,
                                         q_stride,
                                         kv_block_stride,
                                         kv_head_stride,
                                         exp_sums_ptr,
                                         max_logits_ptr,
                                         tmp_out_ptr,
                                         q_scale_ptr,
                                         k_scale_ptr,
                                         v_scale_ptr);

    // ---- Kernel 2: reduction over partitions into the final output ----
    const dim3 reduce_grid(num_heads, num_seqs, {{mtp}});
    const dim3 reduce_block(head_size);

    paged_attention_ll4mi_reduce_kernel<{{out_dtype}},
                                        {{out_dtype}},
                                        head_size,
                                        head_size,
                                        PARTITION_SIZE,
                                        {{npar_loops}}>
        <<<reduce_grid, reduce_block, 0, hip_stream>>>(final_out_ptr,
                                                       exp_sums_ptr,
                                                       max_logits_ptr,
                                                       tmp_out_ptr,
                                                       context_lens_ptr,
                                                       nullptr, // optional kernel argument left unset; see pa.cuh
                                                       max_num_partitions,
                                                       fp8_out_scale_ptr);
}